import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import GraphConvlution


class GCN(nn.Module):
    """Two-layer graph convolutional network.

    Applies ``GraphConvlution -> ReLU -> dropout -> GraphConvlution``. The
    output of the second graph convolution is returned unchanged; no softmax
    or other normalisation is applied in this module.

    Args:
        n_features: Input size given to the first graph convolution.
        n_hidden: Output size of the first graph convolution and input size
            of the second.
        n_classes: Output size of the second graph convolution.
        dropout: Drop probability applied to the hidden activations; it is
            only in effect while the module is in training mode.
    """

    def __init__(self, n_features, n_hidden, n_classes, dropout):
        super().__init__()
        # Registration order (gc1 before gc2) determines parameters() and
        # state_dict ordering, so it is kept fixed.
        self.gc1 = GraphConvlution(n_features, n_hidden)
        self.gc2 = GraphConvlution(n_hidden, n_classes)
        self.dropout = dropout

    def forward(self, x, adj):
        """Run both graph-convolution layers and return the second's output.

        ``x`` and ``adj`` are forwarded unchanged to each ``GraphConvlution``
        call; their expected shapes are defined by that layer (not visible
        here).
        """
        hidden = F.relu(self.gc1(x, adj))
        # Tie dropout to the module's train/eval state rather than always
        # dropping.
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return self.gc2(hidden, adj)